"""
Code for vectorizing input based on the top-N (e.g. top-20) POS-TMV features.
"""
import pickle
import csv
import numpy as np
import pandas as pd
from sklearn.feature_extraction.text import CountVectorizer

p = '/Users/sunyambagga/Desktop/txtLAB-2/detecting-narrativity/'


def _load_pickle(rel_path):
    """Unpickle and return the object stored at p + rel_path (pickles created via pickle_features.py)."""
    with open(p + rel_path, 'rb') as fh:
        return pickle.load(fh)


# Per-corpus TMV feature dicts, looked up as [filename][feature] (see get_tmv_vector).
TMV_FEATURES_TRAIN = _load_pickle('pickles/tmv_features_lite_merged.pickle')
TMV_FEATURES_POETRY = _load_pickle('pickles/POETRY_tense_mood_voice_features_lite.pickle')
TMV_FEATURES_SCIENCE1 = _load_pickle('pickles/SCIENCE-JSTOR_tense_mood_voice_features_lite.pickle')
TMV_FEATURES_SCIENCE2 = _load_pickle('pickles/SCIENCE-ROYAL_tense_mood_voice_features_lite.pickle')

# Single lookup over all corpora; when a filename appears twice, the later corpus wins.
TMV_FEATURES = dict(TMV_FEATURES_TRAIN)
TMV_FEATURES.update(TMV_FEATURES_POETRY)
TMV_FEATURES.update(TMV_FEATURES_SCIENCE1)
TMV_FEATURES.update(TMV_FEATURES_SCIENCE2)
print("Loading TMV features pickle... Size:", len(TMV_FEATURES))

# Feature names served from TMV_FEATURES; every other ranked feature is treated as a POS tag.
tmv_feature_names = [
    'temporality', 'temporal_order', 'setting', 'concreteness', 'saying',
    'eventfulness', 'agenthood', 'agency', 'coh_seq', 'feltness', 'pct_quoted',
]

# Ranked feature names; top_n_model keeps the first N, so presumably most important first.
top_pos_tmv_features = (
    pd.read_csv(p + 'feature-importance/PctQuoted_WithQuotationMarks_pos-tmv.csv')['feature_names']
    .tolist()
)
print("Top POS-TMV features:", top_pos_tmv_features)

# Tokens files are searched here first ...
BOOK_PATH_1 = '/Users/sunyambagga/Desktop/txtLAB-2/minimal-narrativity/booknlp-output-narrativity/'

# ... and then here. Uncomment one (out of the 3) when predicting:
# BOOK_PATH_2 = '/Users/sunyambagga/Desktop/txtLAB-2/detecting-narrativity/booknlp-output-science-jstor/'
# BOOK_PATH_2 = '/Users/sunyambagga/Desktop/txtLAB-2/detecting-narrativity/booknlp-output-science-royal/'
BOOK_PATH_2 = '/Users/sunyambagga/Desktop/txtLAB-2/detecting-narrativity/booknlp-output-poetry/'

print('\n----\nUsing the two BookNLP paths:', BOOK_PATH_1, "\n", BOOK_PATH_2, "\n----\n")


def get_tmv_vector(fname, feats_to_consider):
    """
    Build the TMV feature vector for one file.

    Returns a list holding TMV_FEATURES[fname][feat] for every feat in
    feats_to_consider, in that same order. The lookup is done per feature,
    so an empty feats_to_consider yields [] without touching TMV_FEATURES.
    """
    vector = []
    for feat in feats_to_consider:
        vector.append(TMV_FEATURES[fname][feat])
    return vector


def get_POS_str_feats(fname, pos_feats_to_consider, book_paths=None):
    """
    Returns a string of part-of-speech tags in the given filename.

    Reads the tab-separated BookNLP tokens file
    ``<book_path><fname>/<fname>.tokens`` from the first directory in
    ``book_paths`` where it exists (default: BOOK_PATH_1, then BOOK_PATH_2).
    It lower-cases the 'pos' column, keeps only the tags listed in
    pos_feats_to_consider, and joins them in file order with single spaces.

    Parameters
    ----------
    fname: file identifier, used both as directory name and file stem
    pos_feats_to_consider: collection (list/set) of lower-case POS tags to keep
    book_paths: optional sequence of base directories to search in order;
        each must end with a path separator, as BOOK_PATH_1/2 do

    Raises
    ------
    FileNotFoundError if the tokens file is missing from every directory.
    """
    if book_paths is None:
        book_paths = (BOOK_PATH_1, BOOK_PATH_2)
    rel_path = fname + '/' + fname + '.tokens'

    missing = FileNotFoundError(rel_path)
    for base in book_paths:
        try:
            df = pd.read_csv(base + rel_path, delimiter='\t', quoting=csv.QUOTE_NONE)
        except FileNotFoundError as err:
            # Only a missing file means "try the next directory"; parse errors
            # (or Ctrl-C) must not silently fall through to another corpus.
            missing = err
        else:
            break
    else:
        raise missing

    # Empty/missing tags become "" so .str.lower() works and they never match.
    tags = df['pos'].fillna("").str.lower()
    return ' '.join(tags[tags.isin(pos_feats_to_consider)].tolist())


def _vectorizer_feature_names(vectorizer):
    """
    Return the column names of a fitted CountVectorizer as a list of str.

    get_feature_names() was deprecated in scikit-learn 1.0 and removed in 1.2;
    get_feature_names_out() (which returns an array) replaces it.
    """
    if hasattr(vectorizer, 'get_feature_names_out'):
        return vectorizer.get_feature_names_out().tolist()
    return vectorizer.get_feature_names()


def _as_2d(rows, n_cols):
    """
    Turn a list of equal-length feature rows into a (len(rows), n_cols) array.

    np.array([]) is 1-D, so without the reshape an empty row list could not be
    hstacked with the (0, k) POS count matrix.
    """
    return np.array(rows).reshape(len(rows), n_cols)


def top_n_model(train_x, test_x, N, return_feature_names=False):
    """
    Vectorizes the input files using the top-N POS-TMV model features.

    The first N entries of top_pos_tmv_features are split into two groups.
    TMV features are the names in tmv_feature_names; their values are read via
    get_tmv_vector. POS tags are everything else; they are counted per file
    with a CountVectorizer fitted on the training files only.

    Parameters
    ----------
    train_x: list of train filenames
    test_x: list of test filenames
    N: number of top pos-tmv features to consider
    return_feature_names: if True, also return the list of column names

    Returns
    -------
    X_train, X_test: dense numpy arrays of shape (n_files, n_columns).
        TMV columns come first, in ranking order. POS count columns follow,
        in the vectorizer's (alphabetical) vocabulary order. POS tags that
        never occur in the training files get no column.
    feature_names: column names matching X_train/X_test; returned only when
        return_feature_names is True.
    """
    # The ranking calls the cohesion feature 'coherence', but the TMV feature
    # list calls it 'coh_seq'. Rename it in place, keeping its rank, then drop
    # duplicates. The slice is a copy, so the module-level ranking is not mutated.
    renamed = ['coh_seq' if feat == 'coherence' else feat for feat in top_pos_tmv_features[:N]]
    features_to_consider = list(dict.fromkeys(renamed))

    # Keep ranking order rather than set order: set iteration order of strings
    # changes with PYTHONHASHSEED, which would shuffle TMV columns between runs.
    tmv_feats = [feat for feat in features_to_consider if feat in tmv_feature_names]
    pos_feats = [feat for feat in features_to_consider if feat not in tmv_feature_names]
    print("Total Features: {} | TMV: {} | POS: {}".format(len(features_to_consider), tmv_feats, pos_feats))

    # TMV:
    tmv_train = _as_2d([get_tmv_vector(x, tmv_feats) for x in train_x], len(tmv_feats))
    tmv_test = _as_2d([get_tmv_vector(x, tmv_feats) for x in test_x], len(tmv_feats))

    if not pos_feats:
        if return_feature_names:
            return tmv_train, tmv_test, tmv_feats
        return tmv_train, tmv_test

    # POS: each document is its kept tags joined by spaces. Tokenize on
    # whitespace so every tag is one token; tags such as ',', '-lrb-', '``'
    # and 'prp$' contain non-word characters that a \w-based pattern would
    # drop, truncate, or merge with other tags.
    vectorizer = CountVectorizer(ngram_range=(1, 1), token_pattern=r"\S+", analyzer='word', encoding='utf-8')
    train_sentences = [get_POS_str_feats(x, pos_feats) for x in train_x]
    test_sentences = [get_POS_str_feats(x, pos_feats) for x in test_x]
    pos_train = vectorizer.fit_transform(train_sentences).toarray()
    pos_test = vectorizer.transform(test_sentences).toarray()
    pos_names = _vectorizer_feature_names(vectorizer)
    print("POS Train:", pos_train.shape, "| POS Test:", pos_test.shape, "| POS feature-columns:", pos_names)

    # Combine:
    combined_train = np.hstack((tmv_train, pos_train))
    combined_test = np.hstack((tmv_test, pos_test))
    if return_feature_names:
        return combined_train, combined_test, tmv_feats + pos_names
    return combined_train, combined_test